{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "kCeYA79m1DEX"
   },
   "source": [
    "# Multi-task recommenders"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Dk8QEc4sIPMi"
   },
   "source": [
    "In the [basic retrieval tutorial](basic_retrieval) we built a retrieval system using movie watches as positive interaction signals.\n",
    "\n",
    "In many applications, however, there are multiple rich sources of feedback to draw upon. For example, an e-commerce site may record user visits to product pages (abundant, but relatively low signal), image clicks, adding to cart, and, finally, purchases. It may even record post-purchase signals such as reviews and returns.\n",
    "\n",
    "Integrating all these different forms of feedback is critical to building systems that users love to use, and that do not optimize for any one metric at the expense of overall performance.\n",
    "\n",
    "In addition, building a joint model for multiple tasks may produce better results than building a number of task-specific models. This is especially true where some data is abundant (for example, clicks), and some data is sparse (purchases, returns, manual reviews). In those scenarios, a joint model may be able to use representations learned from the abundant task to improve its predictions on the sparse task via a phenomenon known as [transfer learning](https://en.wikipedia.org/wiki/Transfer_learning). For example, [this paper](https://openreview.net/pdf?id=SJxPVcSonN) shows that a model predicting explicit user ratings from sparse user surveys can be substantially improved by adding an auxiliary task that uses abundant click log data.\n",
    "\n",
    "In this tutorial, we are going to build a multi-objective recommender for Movielens, using both implicit (movie watches) and explicit signals (ratings)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ZwrcZeK7x7xI"
   },
   "source": [
    "## Imports\n",
    "\n",
    "Let's first get our imports out of the way."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "SZGYDaF-m5wZ"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pprint\n",
    "import tempfile\n",
    "\n",
    "from typing import Dict, Text\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "BxQ_hy7xPH3N"
   },
   "outputs": [],
   "source": [
    "import tensorflow_recommenders as tfrs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "5PAqjR4a1RR4"
   },
   "source": [
    "## Preparing the dataset\n",
    "\n",
    "We're going to use the Movielens 1M dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "-ySWtibjm_6a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1mDownloading and preparing dataset movie_lens/1m-ratings/0.1.0 (download: 5.64 MiB, generated: 308.42 MiB, total: 314.06 MiB) to /home/ec2-user/tensorflow_datasets/movie_lens/1m-ratings/0.1.0...\u001b[0m\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Shuffling and writing examples to /home/ec2-user/tensorflow_datasets/movie_lens/1m-ratings/0.1.0.incompleteD6BFGU/movie_lens-train.tfrecord\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1mDataset movie_lens downloaded and prepared to /home/ec2-user/tensorflow_datasets/movie_lens/1m-ratings/0.1.0. Subsequent calls will reuse this data.\u001b[0m\n",
      "\u001b[1mDownloading and preparing dataset movie_lens/1m-movies/0.1.0 (download: 5.64 MiB, generated: 351.12 KiB, total: 5.99 MiB) to /home/ec2-user/tensorflow_datasets/movie_lens/1m-movies/0.1.0...\u001b[0m\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Shuffling and writing examples to /home/ec2-user/tensorflow_datasets/movie_lens/1m-movies/0.1.0.incompleteKX5UKY/movie_lens-train.tfrecord\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1mDataset movie_lens downloaded and prepared to /home/ec2-user/tensorflow_datasets/movie_lens/1m-movies/0.1.0. Subsequent calls will reuse this data.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "ratings = tfds.load('movie_lens/1m-ratings', split=\"train\")\n",
    "movies = tfds.load('movie_lens/1m-movies', split=\"train\")\n",
    "\n",
    "# Select the basic features.\n",
    "ratings = ratings.map(lambda x: {\n",
    "    \"movie_title\": x[\"movie_title\"],\n",
    "    \"user_id\": x[\"user_id\"],\n",
    "    \"user_rating\": x[\"user_rating\"],\n",
    "})\n",
    "movies = movies.map(lambda x: x[\"movie_title\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "JRHorm8W1yf3"
   },
   "source": [
    "As before, we build the user and movie vocabularies and split the data into a train and a test set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "rS0eDfkjnjJL"
   },
   "outputs": [],
   "source": [
    "# Randomly shuffle data and split between train and test.\n",
    "tf.random.set_seed(42)\n",
    "shuffled = ratings.shuffle(1_000_000, seed=42, reshuffle_each_iteration=False)\n",
    "\n",
    "train = shuffled.take(800_000)\n",
    "test = shuffled.skip(800_000).take(200_000)\n",
    "\n",
    "movie_titles = movies.batch(1_000)\n",
    "user_ids = ratings.batch(1_000_000).map(lambda x: x[\"user_id\"])\n",
    "\n",
    "unique_movie_titles = np.unique(np.concatenate(list(movie_titles)))\n",
    "unique_user_ids = np.unique(np.concatenate(list(user_ids)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "eCi-seR86qqa"
   },
   "source": [
    "## A multi-task model\n",
    "\n",
    "There are two critical parts to multi-task recommenders:\n",
    "\n",
    "1. They optimize for two or more objectives, and so have two or more losses.\n",
    "2. They share variables between the tasks, allowing for transfer learning.\n",
    "\n",
    "In this tutorial, we will define our models as before, but instead of having a single task, we will have two tasks: one that predicts ratings, and one that predicts movie watches."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "AXHrk_SLzKCM"
   },
   "source": [
    "The user and movie models are as before:\n",
    "\n",
    "```python\n",
    "user_model = tf.keras.Sequential([\n",
    "  tf.keras.layers.experimental.preprocessing.StringLookup(\n",
    "      vocabulary=unique_user_ids, mask_token=None),\n",
    "  # We add 1 to account for the unknown token.\n",
    "  tf.keras.layers.Embedding(len(unique_user_ids) + 1, embedding_dimension)\n",
    "])\n",
    "\n",
    "movie_model = tf.keras.Sequential([\n",
    "  tf.keras.layers.experimental.preprocessing.StringLookup(\n",
    "      vocabulary=unique_movie_titles, mask_token=None),\n",
    "  tf.keras.layers.Embedding(len(unique_movie_titles) + 1, embedding_dimension)\n",
    "])\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "cWCwkE5z8QBe"
   },
   "source": [
    "However, now we will have two tasks. The first is the rating task:\n",
    "\n",
    "```python\n",
    "tfrs.tasks.Ranking(\n",
    "    loss=tf.keras.losses.MeanSquaredError(),\n",
    "    metrics=[tf.keras.metrics.RootMeanSquaredError()],\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "xrgQIXEm8UWf"
   },
   "source": [
    "Its goal is to predict the ratings as accurately as possible.\n",
    "\n",
    "The second is the retrieval task:\n",
    "\n",
    "```python\n",
    "tfrs.tasks.Retrieval(\n",
    "    metrics=tfrs.metrics.FactorizedTopK(\n",
    "        candidates=movies.batch(128).map(movie_model)\n",
    "    )\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "SCNrv7_gakmF"
   },
   "source": [
    "As before, this task's goal is to predict which movies the user will or will not watch."
   ]
  },
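  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "For reference, the quantities behind the two tasks can be written out. This is a sketch in notation introduced here for illustration: $r_i$ is the observed rating and $\\hat{r}_i$ the predicted rating for example $i$, $u_i$ and $v_i$ are the corresponding user and movie embeddings, $N$ is the number of evaluated examples, and $B$ is the batch size. The rating task minimizes the mean squared error and reports its square root:\n",
    "\n",
    "$$\\text{RMSE} = \\sqrt{\\frac{1}{N}\\sum_{i=1}^{N} (r_i - \\hat{r}_i)^2}$$\n",
    "\n",
    "The retrieval task, assuming it keeps what we understand to be the library default (a softmax cross-entropy over in-batch candidates, summed over the batch; this is not stated elsewhere in this tutorial), treats the other movies in a batch as negatives:\n",
    "\n",
    "$$\\mathcal{L}_{\\text{retrieval}} = -\\sum_{i=1}^{B} \\log \\frac{\\exp(u_i^\\top v_i)}{\\sum_{j=1}^{B} \\exp(u_i^\\top v_j)}$$\n",
    "\n",
    "Under that assumption the retrieval loss grows with the batch size, so its raw values are not directly comparable to the rating loss."
   ]
  },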
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "DSWw3xuq8mGh"
   },
   "source": [
    "### Putting it together\n",
    "\n",
    "We put it all together in a model class.\n",
    "\n",
    "The new component here is that, since we have two tasks and two losses, we need to decide how important each loss is. We can do this by giving each of the losses a weight, and treating these weights as hyperparameters. If we assign a large loss weight to the rating task, our model is going to focus on predicting ratings (but still use some information from the retrieval task); if we assign a large loss weight to the retrieval task, it will focus on retrieval instead."
   ]
  },
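  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Written out, the objective that `compute_loss` returns in the class below is a weighted sum of the two task losses. In this sketch, $w_{\\text{rating}}$ and $w_{\\text{retrieval}}$ are notation introduced here for the `rating_weight` and `retrieval_weight` constructor arguments:\n",
    "\n",
    "$$\\mathcal{L} = w_{\\text{rating}}\\,\\mathcal{L}_{\\text{rating}} + w_{\\text{retrieval}}\\,\\mathcal{L}_{\\text{retrieval}}$$\n",
    "\n",
    "Setting a weight to zero removes that task's contribution to the gradient, although its metrics are still computed and reported. Because the two losses can be on very different scales, equal weights do not necessarily mean equal influence on the shared embeddings."
   ]
  },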
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "YFSkOAMgzU0K"
   },
   "outputs": [],
   "source": [
    "class MovielensModel(tfrs.models.Model):\n",
    "\n",
    "  def __init__(self, rating_weight: float, retrieval_weight: float) -> None:\n",
    "    # We take the loss weights in the constructor: this allows us to instantiate\n",
    "    # several model objects with different loss weights.\n",
    "\n",
    "    super().__init__()\n",
    "\n",
    "    embedding_dimension = 32\n",
    "\n",
    "    # User and movie models.\n",
    "    self.movie_model: tf.keras.layers.Layer = tf.keras.Sequential([\n",
    "      tf.keras.layers.experimental.preprocessing.StringLookup(\n",
    "        vocabulary=unique_movie_titles, mask_token=None),\n",
    "      tf.keras.layers.Embedding(len(unique_movie_titles) + 1, embedding_dimension)\n",
    "    ])\n",
    "    self.user_model: tf.keras.layers.Layer = tf.keras.Sequential([\n",
    "      tf.keras.layers.experimental.preprocessing.StringLookup(\n",
    "        vocabulary=unique_user_ids, mask_token=None),\n",
    "      tf.keras.layers.Embedding(len(unique_user_ids) + 1, embedding_dimension)\n",
    "    ])\n",
    "\n",
    "    # A small model to take in user and movie embeddings and predict ratings.\n",
    "    # We can make this as complicated as we want as long as we output a scalar\n",
    "    # as our prediction.\n",
    "    self.rating_model = tf.keras.Sequential([\n",
    "        tf.keras.layers.Dense(256, activation=\"relu\"),\n",
    "        tf.keras.layers.Dense(128, activation=\"relu\"),\n",
    "        tf.keras.layers.Dense(1),\n",
    "    ])\n",
    "\n",
    "    # The tasks.\n",
    "    self.rating_task: tf.keras.layers.Layer = tfrs.tasks.Ranking(\n",
    "        loss=tf.keras.losses.MeanSquaredError(),\n",
    "        metrics=[tf.keras.metrics.RootMeanSquaredError()],\n",
    "    )\n",
    "    self.retrieval_task: tf.keras.layers.Layer = tfrs.tasks.Retrieval(\n",
    "        metrics=tfrs.metrics.FactorizedTopK(\n",
    "            candidates=movies.batch(128).map(self.movie_model)\n",
    "        )\n",
    "    )\n",
    "\n",
    "    # The loss weights.\n",
    "    self.rating_weight = rating_weight\n",
    "    self.retrieval_weight = retrieval_weight\n",
    "\n",
    "  def call(self, features: Dict[Text, tf.Tensor]) -> tf.Tensor:\n",
    "    # We pick out the user features and pass them into the user model.\n",
    "    user_embeddings = self.user_model(features[\"user_id\"])\n",
    "    # And pick out the movie features and pass them into the movie model.\n",
    "    movie_embeddings = self.movie_model(features[\"movie_title\"])\n",
    "    \n",
    "    return (\n",
    "        user_embeddings,\n",
    "        movie_embeddings,\n",
    "        # We apply the multi-layered rating model to a concatentation of\n",
    "        # user and movie embeddings.\n",
    "        self.rating_model(\n",
    "            tf.concat([user_embeddings, movie_embeddings], axis=1)\n",
    "        ),\n",
    "    )\n",
    "\n",
    "  def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:\n",
    "\n",
    "    user_embeddings, movie_embeddings, rating_predictions = self(features)\n",
    "\n",
    "    # We compute the loss for each task.\n",
    "    rating_loss = self.rating_task(\n",
    "        labels=features[\"user_rating\"],\n",
    "        predictions=rating_predictions,\n",
    "    )\n",
    "    retrieval_loss = self.retrieval_task(user_embeddings, movie_embeddings)\n",
    "\n",
    "    # And combine them using the loss weights.\n",
    "    return (self.rating_weight * rating_loss\n",
    "            + self.retrieval_weight * retrieval_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ngvn-c0b8lc2"
   },
   "source": [
    "### Rating-specialized model\n",
    "\n",
    "Depending on the weights we assign, the model will encode a different balance of the tasks. Let's start with a model that only considers ratings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "NNfB6rYL0VrS"
   },
   "outputs": [],
   "source": [
    "model = MovielensModel(rating_weight=1.0, retrieval_weight=0.0)\n",
    "model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "I6kjfF1j0iZR"
   },
   "outputs": [],
   "source": [
    "cached_train = train.shuffle(100_000).batch(8192).cache()\n",
    "cached_test = test.batch(4096).cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "6NWadH1q0c_T"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/3\n",
      "98/98 [==============================] - 90s 921ms/step - root_mean_squared_error: 1.2454 - factorized_top_k: 0.0089 - factorized_top_k/top_1_categorical_accuracy: 1.4500e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0013 - factorized_top_k/top_10_categorical_accuracy: 0.0027 - factorized_top_k/top_50_categorical_accuracy: 0.0136 - factorized_top_k/top_100_categorical_accuracy: 0.0268 - loss: 1.5459 - regularization_loss: 0.0000e+00 - total_loss: 1.5459\n",
      "Epoch 2/3\n",
      "98/98 [==============================] - 86s 878ms/step - root_mean_squared_error: 1.0252 - factorized_top_k: 0.0106 - factorized_top_k/top_1_categorical_accuracy: 2.3500e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0018 - factorized_top_k/top_10_categorical_accuracy: 0.0036 - factorized_top_k/top_50_categorical_accuracy: 0.0164 - factorized_top_k/top_100_categorical_accuracy: 0.0312 - loss: 1.0498 - regularization_loss: 0.0000e+00 - total_loss: 1.0498\n",
      "Epoch 3/3\n",
      "98/98 [==============================] - 86s 881ms/step - root_mean_squared_error: 0.9529 - factorized_top_k: 0.0100 - factorized_top_k/top_1_categorical_accuracy: 2.0125e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0015 - factorized_top_k/top_10_categorical_accuracy: 0.0031 - factorized_top_k/top_50_categorical_accuracy: 0.0153 - factorized_top_k/top_100_categorical_accuracy: 0.0297 - loss: 0.9075 - regularization_loss: 0.0000e+00 - total_loss: 0.9075\n",
      "49/49 [==============================] - 19s 381ms/step - root_mean_squared_error: 0.9390 - factorized_top_k: 0.0097 - factorized_top_k/top_1_categorical_accuracy: 2.1000e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0015 - factorized_top_k/top_10_categorical_accuracy: 0.0031 - factorized_top_k/top_50_categorical_accuracy: 0.0148 - factorized_top_k/top_100_categorical_accuracy: 0.0290 - loss: 0.8812 - regularization_loss: 0.0000e+00 - total_loss: 0.8812\n",
      "Retrieval top-100 accuracy: 0.029.\n",
      "Ranking RMSE: 0.939.\n"
     ]
    }
   ],
   "source": [
    "model.fit(cached_train, epochs=3)\n",
    "metrics = model.evaluate(cached_test, return_dict=True)\n",
    "\n",
    "print(f\"Retrieval top-100 accuracy: {metrics['factorized_top_k/top_100_categorical_accuracy']:.3f}.\")\n",
    "print(f\"Ranking RMSE: {metrics['root_mean_squared_error']:.3f}.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "lENViv04-i0T"
   },
   "source": [
    "The model does OK on predicting ratings (with an RMSE of around 0.94), but performs poorly at predicting which movies will be watched or not: its top-100 accuracy is more than 5 times lower than that of a model trained solely to predict watches."
   ]
  },
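  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The retrieval metric quoted here can be stated more precisely. As a sketch, with notation introduced for illustration: for each of $N$ test examples, let $\\text{rank}_i$ be the position of the watched movie when all candidate movies are sorted by their score against the user embedding, highest first. Top-$K$ accuracy is then\n",
    "\n",
    "$$\\text{Acc@}K = \\frac{1}{N}\\sum_{i=1}^{N} \\mathbb{1}[\\text{rank}_i \\le K]$$\n",
    "\n",
    "where $\\mathbb{1}[\\cdot]$ equals 1 when its condition holds and 0 otherwise. A value of 0.029 at $K = 100$ therefore means the watched movie appears among the 100 highest-scoring candidates for about 2.9% of test examples."
   ]
  },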
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "yPYd9LtE-4Fm"
   },
   "source": [
    "### Retrieval-specialized model\n",
    "\n",
    "Let's now try a model that focuses on retrieval only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "BfnkGd2G--Qt"
   },
   "outputs": [],
   "source": [
    "model = MovielensModel(rating_weight=0.0, retrieval_weight=1.0)\n",
    "model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "JCCBdM7U_B11"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/3\n",
      "98/98 [==============================] - 86s 879ms/step - root_mean_squared_error: 3.7939 - factorized_top_k: 0.0462 - factorized_top_k/top_1_categorical_accuracy: 5.0125e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0068 - factorized_top_k/top_10_categorical_accuracy: 0.0151 - factorized_top_k/top_50_categorical_accuracy: 0.0742 - factorized_top_k/top_100_categorical_accuracy: 0.1343 - loss: 71323.9254 - regularization_loss: 0.0000e+00 - total_loss: 71323.9254\n",
      "Epoch 2/3\n",
      "98/98 [==============================] - 86s 882ms/step - root_mean_squared_error: 3.8256 - factorized_top_k: 0.0665 - factorized_top_k/top_1_categorical_accuracy: 8.0500e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0119 - factorized_top_k/top_10_categorical_accuracy: 0.0250 - factorized_top_k/top_50_categorical_accuracy: 0.1084 - factorized_top_k/top_100_categorical_accuracy: 0.1861 - loss: 69115.7197 - regularization_loss: 0.0000e+00 - total_loss: 69115.7197\n",
      "Epoch 3/3\n",
      "98/98 [==============================] - 86s 879ms/step - root_mean_squared_error: 3.8383 - factorized_top_k: 0.0711 - factorized_top_k/top_1_categorical_accuracy: 8.1625e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0136 - factorized_top_k/top_10_categorical_accuracy: 0.0280 - factorized_top_k/top_50_categorical_accuracy: 0.1162 - factorized_top_k/top_100_categorical_accuracy: 0.1969 - loss: 68433.7741 - regularization_loss: 0.0000e+00 - total_loss: 68433.7741\n",
      "49/49 [==============================] - 19s 391ms/step - root_mean_squared_error: 3.8453 - factorized_top_k: 0.0552 - factorized_top_k/top_1_categorical_accuracy: 0.0011 - factorized_top_k/top_5_categorical_accuracy: 0.0085 - factorized_top_k/top_10_categorical_accuracy: 0.0180 - factorized_top_k/top_50_categorical_accuracy: 0.0880 - factorized_top_k/top_100_categorical_accuracy: 0.1605 - loss: 31806.8622 - regularization_loss: 0.0000e+00 - total_loss: 31806.8622\n",
      "Retrieval top-100 accuracy: 0.160.\n",
      "Ranking RMSE: 3.845.\n"
     ]
    }
   ],
   "source": [
    "model.fit(cached_train, epochs=3)\n",
    "metrics = model.evaluate(cached_test, return_dict=True)\n",
    "\n",
    "print(f\"Retrieval top-100 accuracy: {metrics['factorized_top_k/top_100_categorical_accuracy']:.3f}.\")\n",
    "print(f\"Ranking RMSE: {metrics['root_mean_squared_error']:.3f}.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "YjM7j7eY_jPh"
   },
   "source": [
    "We get the opposite result: a model that does well on retrieval, but poorly on predicting ratings."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "hOFwjUus_pLU"
   },
   "source": [
    "### Joint model\n",
    "\n",
    "Let's now train a model that assigns positive weights to both tasks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "7xyDbNMf_t8a"
   },
   "outputs": [],
   "source": [
    "model = MovielensModel(rating_weight=1.0, retrieval_weight=1.0)\n",
    "model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "2pZmM_ub_uEO"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/3\n",
      "72/98 [=====================>........] - ETA: 23s - root_mean_squared_error: 1.3590 - factorized_top_k: 0.0417 - factorized_top_k/top_1_categorical_accuracy: 4.3403e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0060 - factorized_top_k/top_10_categorical_accuracy: 0.0133 - factorized_top_k/top_50_categorical_accuracy: 0.0669 - factorized_top_k/top_100_categorical_accuracy: 0.1216 - loss: 72351.8831 - regularization_loss: 0.0000e+00 - total_loss: 72351.8831"
     ]
    }
   ],
   "source": [
    "model.fit(cached_train, epochs=3)\n",
    "metrics = model.evaluate(cached_test, return_dict=True)\n",
    "\n",
    "print(f\"Retrieval top-100 accuracy: {metrics['factorized_top_k/top_100_categorical_accuracy']:.3f}.\")\n",
    "print(f\"Ranking RMSE: {metrics['root_mean_squared_error']:.3f}.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Ni_rkOsaB3f9"
   },
   "source": [
    "The result is a model that performs roughly as well on both tasks as each specialized model.\n",
    "\n",
    "While the results here do not show a clear accuracy benefit from a joint model, multi-task learning is in general an extremely useful tool. We can expect better results when we can transfer knowledge from a data-abundant task (such as clicks) to a closely related data-sparse task (such as purchases)."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "multitask.ipynb",
   "private_outputs": true,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
